import copy
import numpy as np 
from PIL import Image
import torchvision
from torchvision import transforms
from torchvision import datasets as vision_datasets
from torch.utils.data import Dataset 

from .rand_aug import RandAugment


# Per-dataset normalization statistics: dataset name -> [per-channel mean, per-channel std].
# The dataset classes in this module look the entry up by dataset name when
# building their transforms; names missing here fall back to the default
# statistics supplied at the lookup site.
norm_mean_std_dict = dict(
    cifar100=[(0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)],
    clothing1m=[(0.6959, 0.6537, 0.6371), (0.3113, 0.3192, 0.3214)],
)


def get_img_transform(img_size=32,
                      crop_ratio=0.875,
                      is_train=True,
                      resize='rpn',
                      autoaug='randaug',
                      rand_erase=True,
                      extra_aug=True,
                      norm_mean_std=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))):
    """Build the torchvision preprocessing pipeline for a single image view.

    Every pipeline ends with ``ToTensor`` followed by ``Normalize``; the steps
    in front of them depend on ``is_train``, ``resize`` and ``autoaug``.

    Args:
        img_size: Side length of the square output image.
        crop_ratio: For the ``'resize_rpc'`` / ``'resize_crop'`` modes the image
            is first resized to ``int(img_size / crop_ratio)`` per side. For
            ``'resize_crop_pad'`` it sets the reflect padding to
            ``int(img_size * (1 - crop_ratio))``.
        is_train: When false, only a deterministic resize (plus a center crop
            for ``resize == 'resize_crop'``) is applied and ``autoaug`` /
            ``rand_erase`` are ignored.
        resize: Resize/crop mode. Training recognises ``'rpc'``,
            ``'resize_rpc'``, ``'resize_crop'`` and ``'resize_crop_pad'``;
            evaluation only distinguishes ``'resize_crop'`` from everything
            else.
        autoaug: Augmentation policy for training: ``'randaug'``,
            ``'autoaug_cifar'``, ``'autoaug'`` or ``None`` (no policy).
        rand_erase: Append ``RandomErasing`` after normalization. Only takes
            effect with the ``'autoaug_cifar'`` and ``'autoaug'`` policies.
        extra_aug: Unused; kept so existing callers keep working.
        norm_mean_std: Arguments unpacked into ``transforms.Normalize``
            (mean first, then std).

    Returns:
        A ``transforms.Compose`` wrapping the assembled steps.

    Raises:
        NotImplementedError: If ``is_train`` is true and ``autoaug`` is not one
            of the supported policies.
    """
    if not is_train:
        # Evaluation: deterministic geometry only, no augmentation.
        if resize == 'resize_crop':
            enlarged = int(img_size / crop_ratio)
            steps = [
                transforms.Resize((enlarged, enlarged)),
                transforms.CenterCrop((img_size, img_size)),
            ]
        else:
            steps = [transforms.Resize((img_size, img_size))]
        steps.extend([
            transforms.ToTensor(),
            transforms.Normalize(*norm_mean_std),
        ])
        return transforms.Compose(steps)

    # Reject unknown policies before building anything, and say which value
    # was rejected.
    if autoaug not in ('randaug', 'autoaug_cifar', 'autoaug', None):
        raise NotImplementedError(f'unsupported autoaug policy: {autoaug!r}')

    steps = []
    if resize == 'rpc':
        steps.append(transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)))
    elif resize == 'resize_rpc':
        enlarged = int(img_size / crop_ratio)
        steps.append(transforms.Resize((enlarged, enlarged)))
        steps.append(transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)))
    elif resize == 'resize_crop':
        enlarged = int(img_size / crop_ratio)
        steps.append(transforms.Resize((enlarged, enlarged)))
        steps.append(transforms.RandomCrop((img_size, img_size)))
    elif resize == 'resize_crop_pad':
        steps.append(transforms.Resize((img_size, img_size)))
        steps.append(transforms.RandomCrop(
            (img_size, img_size),
            padding=int(img_size * (1 - crop_ratio)),
            padding_mode='reflect',
        ))
    # NOTE(review): any other value -- including the default 'rpn', which
    # matches no branch above -- adds no resize/crop step at all during
    # training. Possibly a typo for 'rpc'; confirm before changing the default.
    steps.append(transforms.RandomHorizontalFlip())

    if autoaug == 'randaug':
        # Arguments are passed through as-is; presumably (num_ops, magnitude)
        # -- confirm against rand_aug.RandAugment.
        steps.append(RandAugment(3, 5))
    elif autoaug == 'autoaug_cifar':
        steps.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
    elif autoaug == 'autoaug':
        steps.append(transforms.AutoAugment())

    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(*norm_mean_std),
    ])

    # Random erasing operates on the normalized tensor and is only combined
    # with the AutoAugment policies (never with RandAugment or no policy).
    if rand_erase and autoaug in ('autoaug_cifar', 'autoaug'):
        steps.append(transforms.RandomErasing())

    # Echo the assembled pipeline to stdout (existing behavior).
    print(steps)

    return transforms.Compose(steps)


class ImgBaseDataset(Dataset):
    """Single-view image dataset yielding dict samples.

    Each item is a dict built by pairing ``return_keys`` with
    ``[transformed_image, target]``; with the default keys that is
    ``{'x_lb': image, 'y_lb': target}``.

    Args:
        data_name: Dataset name; selects the normalization statistics from
            ``norm_mean_std_dict`` (a fixed default is used for unknown names).
        data: Indexable collection whose items are either file paths (``str``)
            or arrays accepted by ``PIL.Image.fromarray``.
        targets: Indexable collection of targets aligned with ``data``.
        is_train: Forwarded to ``get_img_transform``.
        num_classes: Stored on the instance; not used by this class itself.
        img_size: Forwarded to ``get_img_transform``.
        crop_ratio: Forwarded to ``get_img_transform``.
        autoaug: Forwarded to ``get_img_transform``.
        resize: Forwarded to ``get_img_transform``.
        return_target: Stored on the instance. This class always returns the
            target regardless of the flag.
        return_keys: Keys for the returned dict, in the order
            ``(image, target)``. Copied into a fresh list per instance.
    """

    def __init__(self, data_name, data, targets, is_train=True, num_classes=10,
                 img_size=32, crop_ratio=0.875, autoaug='randaug', resize='rpc',
                 return_target=False, return_keys=('x_lb', 'y_lb')):
        super().__init__()

        self.data_name = data_name
        self.data = data
        self.targets = targets
        self.num_classes = num_classes
        self.return_target = return_target
        # Copy so instances never share (or mutate) one default key list.
        self.return_keys = list(return_keys)

        fallback_stats = [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)]
        self.transform = get_img_transform(
            img_size,
            crop_ratio,
            is_train=is_train,
            resize=resize,
            autoaug=autoaug,
            norm_mean_std=norm_mean_std_dict.get(data_name, fallback_stats),
        )

    def __getitem__(self, index):
        """Return the transformed image and its target as a keyed dict."""
        sample, target = self.data[index], self.targets[index]
        if isinstance(sample, str):
            # Paths are decoded lazily; the context manager closes the file
            # once the RGB copy has been made.
            with Image.open(sample) as handle:
                image = handle.convert('RGB')
        else:
            image = Image.fromarray(sample)
        items = (self.transform(image), target)
        # zip stops at the shorter sequence, so fewer keys means fewer entries.
        return dict(zip(self.return_keys, items))

    def __len__(self):
        """Return the number of samples in ``data``."""
        return len(self.data)



class ImgTwoViewBaseDataset(ImgBaseDataset):
    """Image dataset yielding two differently transformed views per sample.

    The inherited ``transform`` is built with ``autoaug=None`` (no
    augmentation policy) and produces the first, "weak" view;
    ``strong_transform`` is built with the requested ``autoaug`` policy and
    produces the second, "strong" view. Items are dicts pairing
    ``return_keys`` with ``[weak_view, strong_view]`` and, when
    ``return_target`` is true, the target as a third entry.

    Args:
        data_name: Dataset name; selects the normalization statistics from
            ``norm_mean_std_dict`` (a fixed default is used for unknown names).
        data: Indexable collection whose items are either file paths (``str``)
            or arrays accepted by ``PIL.Image.fromarray``.
        targets: Indexable collection of targets aligned with ``data``.
        is_train: Forwarded to ``get_img_transform`` for both views.
        num_classes: Forwarded to the base class.
        img_size: Forwarded to ``get_img_transform`` for both views.
        crop_ratio: Forwarded to ``get_img_transform`` for both views.
        autoaug: Augmentation policy for the strong view only.
        resize: Forwarded to ``get_img_transform`` for both views.
        return_target: Whether to include the target in each item.
        return_keys: Keys for the returned dict, in the order
            ``(weak_view, strong_view, target)``.
    """

    def __init__(self, data_name, data, targets, is_train=True, num_classes=10,
                 img_size=32, crop_ratio=0.875, autoaug='randaug', resize='rpc',
                 return_target=False, return_keys=('x_ulb_w', 'x_ulb_s', 'y_ulb')):
        # Keyword arguments keep this call correct even if the base signature
        # is reordered; autoaug=None gives the weak (policy-free) transform.
        super().__init__(
            data_name,
            data,
            targets,
            is_train=is_train,
            num_classes=num_classes,
            img_size=img_size,
            crop_ratio=crop_ratio,
            autoaug=None,
            resize=resize,
            return_target=return_target,
            # Fresh list per instance; never share the default.
            return_keys=list(return_keys),
        )
        fallback_stats = [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)]
        self.strong_transform = get_img_transform(
            img_size,
            crop_ratio,
            is_train=is_train,
            resize=resize,
            autoaug=autoaug,
            norm_mean_std=norm_mean_std_dict.get(data_name, fallback_stats),
        )

    def __getitem__(self, index):
        """Return the weak and strong views (and optionally the target)."""
        sample, target = self.data[index], self.targets[index]
        if isinstance(sample, str):
            # Paths are decoded lazily; the context manager closes the file
            # once the RGB copy has been made.
            with Image.open(sample) as handle:
                image = handle.convert('RGB')
        else:
            image = Image.fromarray(sample)

        # Weak view first, then strong view, both from the same source image.
        items = [self.transform(image), self.strong_transform(image)]
        if self.return_target:
            items.append(target)
        # zip stops at the shorter sequence, so the target key is dropped
        # automatically when no target was appended.
        return dict(zip(self.return_keys, items))
    
